import torch
from torch.distributions import Normal
from torch import nn
from torch.nn import ModuleList


def init_leaky(m):
    """Xavier-uniform initialise ``m`` for a LeakyReLU-activated layer.

    Meant to be passed to ``nn.Module.apply``: only ``nn.Linear`` modules are
    modified, every other module is left untouched.

    Args:
        m: module visited by ``apply``.
    """
    # isinstance, not ``m == nn.Linear``: that compared a module *instance*
    # with the *class*, was always False, and so the init never ran.
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('leaky_relu'))
        # Layers built with bias=False have ``m.bias is None``.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)

# Initialization for tanh-activated layers
def init_tanh(m):
    """Xavier-uniform initialise ``m`` for a tanh-activated layer.

    Meant to be passed to ``nn.Module.apply``: only ``nn.Linear`` modules are
    modified, every other module is left untouched.

    Args:
        m: module visited by ``apply``.
    """
    # isinstance, not ``m == nn.Linear``: that compared a module *instance*
    # with the *class*, was always False, and so the init never ran.
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('tanh'))
        # Layers built with bias=False have ``m.bias is None``.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)

# Initialization for linear (no activation) layers
def init_lin(m):
    """Xavier-uniform initialise ``m`` for a layer with no activation.

    Meant to be passed to ``nn.Module.apply``: only ``nn.Linear`` modules are
    modified, every other module is left untouched.

    Args:
        m: module visited by ``apply``.
    """
    # isinstance, not ``m == nn.Linear``: that compared a module *instance*
    # with the *class*, was always False, and so the init never ran.
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('linear'))
        # Layers built with bias=False have ``m.bias is None``.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)

# Initialization for ReLU-activated layers
def init_relu(m):
    """Xavier-uniform initialise ``m`` for a ReLU-activated layer.

    Meant to be passed to ``nn.Module.apply``: only ``nn.Linear`` modules are
    modified, every other module (including ``nn.Conv2d``) is left untouched.

    Args:
        m: module visited by ``apply``.
    """
    # isinstance, not ``m == nn.Linear``: that compared a module *instance*
    # with the *class*, was always False, and so the init never ran.
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
        # Layers built with bias=False have ``m.bias is None``.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)

def init_default(m):
    """Xavier-uniform initialise ``m`` with unit gain.

    Fallback used when no activation-specific gain applies. Meant to be passed
    to ``nn.Module.apply``: only ``nn.Linear`` modules are modified, every
    other module is left untouched.

    Args:
        m: module visited by ``apply``.
    """
    # isinstance, not ``m == nn.Linear``: that compared a module *instance*
    # with the *class*, was always False, and so the init never ran.
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        # Layers built with bias=False have ``m.bias is None``.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)

class Actor_SAC(nn.Module):
    """Tanh-squashed Gaussian policy head for Soft Actor-Critic.

    ``mlp`` turns observations into features. Two linear heads then produce
    the mean and the log standard deviation of a diagonal ``Normal``. A sample
    is squashed with ``tanh`` and affinely mapped into
    ``[action_space.low, action_space.high]``. Its log-probability is corrected
    for that change of variables (https://arxiv.org/pdf/1812.05905.pdf).

    Args:
        mlp: feature network; must expose ``num_output`` and be callable as
            ``mlp(inputs, **kwargs)``.
        n_in: not used by this class.
        n_out: action dimensionality (width of both heads).
        action_space: object with array-like ``high`` / ``low`` bounds
            (presumably a gym ``Box`` -- confirm with callers).
        args: config object; only ``args.device`` is read.
    """

    def __init__(self, mlp, n_in, n_out, action_space, args):
        super().__init__()
        self.mlp = mlp
        device = args.device
        # Affine map from tanh's (-1, 1) onto [low, high]. These are registered
        # as buffers so that ``.to()`` / ``.cuda()`` move them together with the
        # parameters. ``persistent=False`` keeps them out of ``state_dict``, so
        # checkpoints saved before this change still load with strict=True.
        self.register_buffer(
            'action_scale',
            torch.as_tensor((action_space.high - action_space.low) / 2.,
                            dtype=torch.float32, device=device),
            persistent=False,
        )
        self.register_buffer(
            'action_bias',
            torch.as_tensor((action_space.high + action_space.low) / 2.,
                            dtype=torch.float32, device=device),
            persistent=False,
        )
        self.dist = Normal
        # Mean head and (log-)std head on top of the shared features.
        self.mu_gen = nn.Linear(in_features=self.mlp.num_output, out_features=n_out)
        self.mu_gen.apply(init_lin)
        self.std_gen = nn.Linear(in_features=self.mlp.num_output, out_features=n_out)
        self.std_gen.apply(init_lin)
        # Log-std bounds. They are currently NOT applied in ``forward``.
        self.LOG_SIG_MAX = 2
        self.LOG_SIG_MIN = -20

    def bound_action(self, action, action_logprob, dist):
        """Squash and rescale a raw Gaussian sample and fix up its log-prob.

        Args:
            action: pre-squash sample ``u``; assumed (batch, n_out).
            action_logprob: per-dimension log-density of ``u`` under ``dist``,
                same shape as ``action``. It is not modified.
            dist: the distribution ``u`` came from; returned unchanged.

        Returns:
            ``(a, log_prob, dist)`` where ``a = tanh(u) * action_scale + action_bias``
            and ``log_prob`` is the corrected log-density summed over dim 1,
            reshaped to ``(-1, 1)``.
        """
        squashed = torch.tanh(action)
        action = squashed * self.action_scale + self.action_bias

        # Change of variables: log|da/du| = log(scale * (1 - tanh(u)^2)).
        # The 1e-6 avoids log(0) when tanh saturates.
        # https://arxiv.org/pdf/1812.05905.pdf
        correction = torch.log(self.action_scale * (1 - squashed.pow(2)) + 1e-6)
        # Out-of-place, so the caller's log-prob tensor is not mutated.
        action_logprob = (action_logprob - correction).sum(dim=1)
        return action, action_logprob.view(-1, 1), dist

    def bound_action_old(self, action, action_logprob, dist):
        """Legacy variant of :meth:`bound_action`; not called by ``forward``.

        It subtracts the correction *already summed over dim 1*, so
        ``action_logprob`` is presumably pre-summed per sample (shape
        ``(batch,)``). Passing per-dimension log-probs, as ``forward`` produces,
        would broadcast incorrectly. It modifies ``action_logprob`` in place.
        """
        action2 = torch.tanh(action)
        action = action2 * self.action_scale + self.action_bias

        # https://arxiv.org/pdf/1812.05905.pdf
        action_logprob -= (torch.log(self.action_scale * (1 - action2.pow(2)) + 1e-6)).sum(dim=1)
        return action, action_logprob.view(-1, 1), dist

    def forward(self, inputs, deterministic=False, **kwargs):
        """Compute a bounded action, its log-probability and the raw Gaussian.

        Args:
            inputs: observations passed straight to ``self.mlp``.
            deterministic: if True, use the mean instead of a reparameterised
                sample. The log-prob is then the density at the mean.
            **kwargs: forwarded to ``self.mlp``.

        Returns:
            ``(action, log_prob, dist)`` as produced by :meth:`bound_action`.
        """
        actor_features = self.mlp(inputs, **kwargs)
        mu = self.mu_gen(actor_features)
        log_std = self.std_gen(actor_features)
        # NOTE(review): log_std is not clamped to [LOG_SIG_MIN, LOG_SIG_MAX]
        # (a clamp was previously commented out here). Extreme outputs can make
        # ``scale`` inf or 0; confirm this is intended.
        scale = torch.exp(log_std)
        dist = self.dist(mu, scale)
        # In both modes the tanh/rescale happens inside bound_action.
        naction = mu if deterministic else dist.rsample()
        naction_log_probs = dist.log_prob(naction)  # per-dimension, same shape as mu
        return self.bound_action(naction, naction_log_probs, dist)


class MLP(nn.Module):
    """Multi-layer perceptron stored as a flat ``ModuleList`` of layers.

    Each hidden layer is built as Linear -> [BatchNorm1d] -> activation ->
    [Dropout]. With ``last_linear=True`` the stack ends in a bare Linear to
    ``num_output``. Otherwise that output Linear is followed by the same
    [BatchNorm1d] -> activation -> [Dropout] tail, and one hidden layer is
    dropped to compensate.

    Args:
        num_inputs: input feature width.
        num_output: output feature width (exposed as ``self.num_output``).
        hidden_size: width of every hidden layer.
        activation: activation *class*; also selects the weight-init function.
        num_layers: number of hidden layers; 0 builds a single Linear.
        args: not used by this class.
        last_linear: end with a plain Linear (True) or an activated one.
        dropout: dropout probability after each activation; 0 disables it.
        init_zero: must be None when ``num_layers == 0``; otherwise unused here.
        batch_norm: add BatchNorm1d after every Linear that is followed by an
            activation.
        goal_size: extra input width added to every Linear except the first,
            or to the only Linear when ``num_layers == 0``. ``forward`` does
            not concatenate anything, so presumably a caller or subclass
            supplies the wider input -- TODO confirm.
    """

    def __init__(self, num_inputs, num_output, hidden_size=64, activation=nn.Tanh, num_layers=1, args=None,
                 last_linear=True, dropout=0, init_zero=None, batch_norm=False, goal_size=None):
        super(MLP, self).__init__()
        assert num_layers >= 0, "cant have less than 0 hidden_ layers"
        self.num_output = num_output
        self.num_inputs = num_inputs
        # Pick the weight-init routine whose gain matches the activation.
        if activation == nn.Tanh:
            activ_fc = init_tanh
        elif activation == nn.LeakyReLU:
            activ_fc = init_leaky
        elif activation == nn.ReLU:
            activ_fc = init_relu
        else:
            activ_fc = init_default
        self.activ_fc = activ_fc
        goal_extra = goal_size if goal_size is not None else 0

        if num_layers == 0:
            assert last_linear, "can not have 0 hidden layers and no last linear layer"
            assert init_zero is None, "can not currently handle init_zero with 0 hidden layers"
            self.model = nn.Sequential(nn.Linear(num_inputs + goal_extra, num_output))
        else:
            if not last_linear:
                num_layers = num_layers - 1
            layers = []

            def _append_post_linear(width):
                # Tail placed after a Linear of output ``width``. BatchNorm must
                # match that width.
                if batch_norm:
                    layers.append(nn.BatchNorm1d(width))
                layers.append(activation())
                if dropout > 0:
                    layers.append(nn.Dropout(p=dropout))

            layers.append(nn.Linear(num_inputs, hidden_size))
            _append_post_linear(hidden_size)
            for _ in range(num_layers - 1):
                layers.append(nn.Linear(hidden_size + goal_extra, hidden_size))
                _append_post_linear(hidden_size)
            layers.append(nn.Linear(hidden_size + goal_extra, num_output))
            if not last_linear:
                # The output Linear emits ``num_output`` features. The previous
                # BatchNorm1d(hidden_size) crashed whenever hidden_size != num_output.
                _append_post_linear(num_output)
            self.model = ModuleList(layers).apply(activ_fc)

            if last_linear:
                # The final bare Linear gets a unit-gain init.
                layers[-1].apply(init_lin)

        self.train()

    def get_gradient(self):
        """Return the total L1 norm of all parameter gradients.

        Parameters whose ``grad`` is None (before the first backward pass, or
        frozen parameters) are skipped. If none have a gradient, a zero scalar
        is returned instead of raising.
        """
        per_param = [torch.norm(p.grad.detach(), 1) for p in self.parameters() if p.grad is not None]
        if not per_param:
            return torch.zeros(())
        # L1 norm of the (non-negative) per-parameter L1 norms == their sum.
        return torch.norm(torch.stack(per_param), 1)

    def forward(self, inputs):
        """Apply every layer of ``self.model`` in order."""
        out = inputs
        for layer in self.model:
            out = layer(out)
        return out


class SkewfitConvNet(nn.Module):
    """Three-layer conv encoder followed by a linear projection.

    The flatten width of 576 equals 64 channels * 3 * 3, so the final feature
    map must be 3x3. This holds, for example, for 48x48 (or 50x50) inputs,
    which the original notes assumed -- confirm with callers.

    Args:
        num_output: width of the output features (exposed as ``self.num_output``).
        args: stored on ``self.args``; not otherwise read here.
    """

    def __init__(self, num_output, args):
        super().__init__()
        self.num_output = num_output
        # Output sizes use floor((in - kernel) / stride) + 1 with no padding.
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=3, padding=0),  # 48 -> 15 (50 -> 16)
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=0),  # 15 -> 7 (16 -> 7)
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=0),  # 7 -> 3
            # No activation after the last conv.
            nn.Flatten()  # 64 * 3 * 3 = 576
        )
        # NOTE(review): init_relu only acts on nn.Linear, so this call leaves the
        # Conv2d layers at PyTorch's default initialisation.
        self.model.apply(init_relu)
        self.last_fc = nn.Linear(576, num_output)
        self.last_fc.apply(init_lin)

        self.args = args
        self.train()

    def get_gradient(self):
        """Return the total L1 norm of all parameter gradients.

        Parameters whose ``grad`` is None (before the first backward pass, or
        frozen parameters) are skipped. If none have a gradient, a zero scalar
        is returned instead of raising.
        """
        per_param = [torch.norm(p.grad.detach(), 1) for p in self.parameters() if p.grad is not None]
        if not per_param:
            return torch.zeros(())
        # L1 norm of the (non-negative) per-parameter L1 norms == their sum.
        return torch.norm(torch.stack(per_param), 1)

    def forward(self, inputs, store=False):
        """Encode a batch of 3-channel images into ``num_output`` features.

        Args:
            inputs: image batch; assumed (batch, 3, H, W) with a 3x3 final map
                (e.g. H = W = 48) -- TODO confirm.
            store: accepted for interface compatibility; not used.
        """
        features = self.model(inputs)
        return self.last_fc(features)
